from .detect import FastDetectGPT as _FastDetectGPT

class AIDetector:
    """
    Wrapper around FastDetectGPT for identifying machine-generated text.
    """

    # Ordered (upper_bound, label) pairs used to translate a detection
    # probability into a human-readable confidence description; any
    # probability at or above the last bound falls into the final label.
    _CONFIDENCE_BANDS = (
        (0.3, "Low likelihood of AI generation"),
        (0.5, "Moderate likelihood of AI generation"),
        (0.7, "High likelihood of AI generation"),
    )

    def __init__(self,
                 model_name="meta-llama/Llama-3.1-8B",
                 reference_model_name=None,
                 device="cuda"):
        """
        Set up the underlying FastDetectGPT detector.

        Args:
            model_name (str): Name of the scoring model to use
            reference_model_name (str, optional): Name of the reference model.
                                                  Defaults to the scoring model if None.
            device (str): Device to run the detection on. Default is "cuda".
        """
        self._detector = _FastDetectGPT(model_name=model_name,
                                        reference_model_name=reference_model_name,
                                        device=device)

    def detect(self, text, max_length=2048):
        """
        Run machine-generation detection on a single text.

        Args:
            text (str): The text to analyze
            max_length (int, optional): Maximum token length to process. Defaults to 2048.

        Returns:
            dict: Detection results with 'criterion' and 'probability'
        """
        # Delegate straight to the wrapped FastDetectGPT instance.
        return self._detector.detect(text, max_length)

    def detect_batch(self, texts, max_length=2048):
        """
        Run machine-generation detection on several texts at once.

        Args:
            texts (list): List of texts to analyze
            max_length (int, optional): Maximum token length to process per text. Defaults to 2048.

        Returns:
            list: List of detection results for each text
        """
        # Delegate straight to the wrapped FastDetectGPT instance.
        return self._detector.detect_batch(texts, max_length)

    def analyze_paper(self, paper):
        """
        Analyze a research paper to determine if it's machine-generated.

        Args:
            paper (dict): Research paper generated by CycleResearcher.
                          Expected keys (all optional): 'title', 'abstract', 'latex'.

        Returns:
            dict: Comprehensive detection analysis with keys 'criterion',
                  'probability', 'is_likely_ai_generated', and 'confidence_level'.
        """
        # Join the paper's sections into one space-separated string so the
        # detector scores the document as a whole.
        sections = (paper.get("title", ""),
                    paper.get("abstract", ""),
                    paper.get("latex", ""))
        combined_text = " ".join(sections)

        result = self.detect(combined_text)
        probability = result['probability']

        # Augment the raw detector output with a boolean verdict and a
        # descriptive confidence label.
        return {
            "criterion": result['criterion'],
            "probability": probability,
            "is_likely_ai_generated": probability > 0.5,
            "confidence_level": self._get_confidence_level(probability),
        }

    def _get_confidence_level(self, probability):
        """
        Convert probability to a descriptive confidence level.

        Args:
            probability (float): Probability of being machine-generated

        Returns:
            str: Confidence level description
        """
        # Walk the ordered bands; the first upper bound the probability
        # falls below determines the label.
        for upper_bound, label in self._CONFIDENCE_BANDS:
            if probability < upper_bound:
                return label
        return "Very high likelihood of AI generation"

def detect_paper(paper,
                 model_name="meta-llama/Llama-3.1-8B",
                 device="cuda"):
    """
    Convenience function to quickly detect if a paper is machine-generated.

    Builds a throwaway :class:`AIDetector` and runs its paper analysis.

    Args:
        paper (dict): Research paper to analyze
        model_name (str): Detection model to use
        device (str): Device to run detection on

    Returns:
        dict: Detection analysis results
    """
    return AIDetector(model_name=model_name, device=device).analyze_paper(paper)